{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![open.svg](images/open.png)"
   ]
  },
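  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Stock Price Prediction with an MLP"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Logistic Regression Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A model for classification problems. From the features (attributes) of a sample, it computes the probability P(x) that the sample belongs to a given class, then assigns the class according to that probability.\n",
    "- Main use case: binary classification"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Logistic regression formula\n",
    "\n",
    "![logistic_regression formula](images/06_logistic_regression_formula.png)\n",
    "\n",
    "Here y is the predicted class, P is the probability, x is the feature value, and a and b are constants."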
   ]
  },
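  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The formula above is available only as an image, so the following is a minimal sketch that assumes the standard logistic form P(x) = 1 / (1 + exp(-(a*x + b))). The function names, the 0.5 decision threshold, and the values chosen for a and b are illustrative assumptions, not taken from the original material."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "def logistic_probability(x, a, b):\n",
    "    # Assumed standard logistic form: P(x) = 1 / (1 + exp(-(a*x + b)))\n",
    "    return 1.0 / (1.0 + math.exp(-(a * x + b)))\n",
    "\n",
    "def classify_by_probability(x, a, b, threshold=0.5):\n",
    "    # Assign class 1 when P(x) reaches the (assumed) threshold, otherwise class 0\n",
    "    return 1 if logistic_probability(x, a, b) >= threshold else 0\n",
    "\n",
    "demo_a, demo_b = 1.0, 0.0  # hypothetical constants, for illustration only\n",
    "for demo_x in (-2.0, 0.0, 2.0):\n",
    "    print(demo_x, round(logistic_probability(demo_x, demo_a, demo_b), 3), classify_by_probability(demo_x, demo_a, demo_b))"
   ]
  },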
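  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model Evaluation Review\n",
    "\n",
    "- Goal: use model evaluation to compare models and to choose suitable model parameters (or parameter sets)\n",
    "- Method: compute prediction accuracy on the test set to assess model performance"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Pima Indians Diabetes Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The goal is to predict whether a patient has diabetes from diagnostic measurements included in the dataset.\n",
    "\n",
    "Input variables: number of pregnancies, glucose level, blood pressure, skinfold thickness, body mass index (BMI), insulin level, diabetes pedigree function, and age.\n",
    "\n",
    "Output: whether the patient has diabetes.\n",
    "\n",
    "Data source: [Pima Indians Diabetes dataset](https://www.kaggle.com/uciml/pima-indians-diabetes-database)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Task:** Predict whether a patient has diabetes from four features: number of pregnancies, insulin level, BMI, and age."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Typical logistic regression probability curve\n",
    "\n",
    "![unrolled_RNN](images/unrolled_RNN.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Predicting the value at time t+1 from data at time t and earlier (sliding-window prediction)\n",
    "\n",
    "![sliding_window_time_series.svg](images/sliding_window_time_series.png)"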
   ]
  },
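  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sliding-window idea can be sketched as follows. The helper `make_sliding_windows` and its `window` parameter are hypothetical names introduced for illustration. With `window=1` it produces the same pairing of the value at time t with the value at time t+1 that this notebook builds by slicing the training array."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def make_sliding_windows(series, window=1):\n",
    "    # Pair each run of `window` consecutive values with the value that follows it\n",
    "    series = np.asarray(series, dtype=float)\n",
    "    n_samples = len(series) - window\n",
    "    if n_samples <= 0:\n",
    "        raise ValueError('series must be longer than the window')\n",
    "    inputs = np.stack([series[i:i + window] for i in range(n_samples)])\n",
    "    targets = series[window:]\n",
    "    return inputs, targets\n",
    "\n",
    "# Made-up series, for illustration only\n",
    "demo_inputs, demo_targets = make_sliding_windows([10, 11, 12, 13, 14], window=2)\n",
    "print(demo_inputs)\n",
    "print(demo_targets)"
   ]
  },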
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# train LSTM on five years of Google \n",
    "# Supervised Deep Learning\n",
    "%matplotlib inline\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "\n",
    "# Importing the training set\n",
    "training_set_ori = pd.read_csv(\"transfer_data1.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>x</th>\n",
       "      <th>y</th>\n",
       "      <th>y2</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>-5.0</td>\n",
       "      <td>25.00</td>\n",
       "      <td>26.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>-4.9</td>\n",
       "      <td>24.01</td>\n",
       "      <td>25.21</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>-4.8</td>\n",
       "      <td>23.04</td>\n",
       "      <td>24.44</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>-4.7</td>\n",
       "      <td>22.09</td>\n",
       "      <td>23.69</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>-4.6</td>\n",
       "      <td>21.16</td>\n",
       "      <td>22.96</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "     x      y     y2\n",
       "0 -5.0  25.00  26.00\n",
       "1 -4.9  24.01  25.21\n",
       "2 -4.8  23.04  24.44\n",
       "3 -4.7  22.09  23.69\n",
       "4 -4.6  21.16  22.96"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "training_set_ori.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEWCAYAAABrDZDcAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3Xmc1vP+//HHqz1K2iy/tAkllTCSg7JFOllC4hBKEpWiHLKcODi2lCIlWkgoX0JxSFkjNCWFUqQ0isrSvk7v3x/va5yRqZlm5nO9r+V5v93mNnNd85m5ntfQ9breuznnEBGR9FUidAAREQlLhUBEJM2pEIiIpDkVAhGRNKdCICKS5lQIRETSnAqBSAGY2RIzOz2i3/2VmZ0cxe8WKQgVAklqZnaimX1sZmvM7Fcz+8jMjo1970ozmx4gkzOzDWa23sx+NLOBZlZyV9c7545wzr0Xx4gif1IqdACRwjKzfYDJwLXABKAMcBKwJWSumCOdc9+aWQPgPWAhMDz3BWZWyjm3PUQ4kdzUIpBkdhiAc+5551y2c26Tc26Kc26umR2Of+E9PvbO/HcAM6tkZs+Y2SozW2pmt5vZH/8OzOxqM5tvZuvM7GszO3rnBzWzBmb2vZldnF9A59wC4EOgUexnl5jZzWY2F9hgZqVydzuZWUkzu9XMvotlmGVmNXM97tuxls83ZnZRkf+CIqgQSHJbCGSb2dNmdpaZVc75hnNuPtANmOGcq+Cc2zf2rUeBSsDBQEvgcqATgJm1B+6M3bcPcA7wS+4HjBWGKUBP59wL+QU0s4b4Vsrnue6+BPg7sG8eLYIbY99vE8vQGdhoZnsDbwPPAfvFrnnczI7IL4NIflQIJGk559YCJwIOeBJYZWavmdn+eV0f66fvAPRzzq1zzi0BHgY6xi7pAjzonJvpvG+dc0tz/YqTgNeAK5xzk/OJN9vMfgMmAU8Bo3N9b4hzbplzblMeP9cFuN05900swxfOuV+AtsAS59xo59x259xs4CXgwnxyiORLYwSS1GLv/K8E33UCPAs8gn/HvLNq+HGE3C/uS4Easa9rAt/t5uG6Ae87594tQLSjnXPf7uJ7y3bzc7vKUBs4LqeLK6YUMLYAWUR2Sy0CSRmx/vgxxPrj8S2F3FYD2/AvqjlqAT/Gvl4G1NvNQ3QDapnZoKJG3c33dpVhGb4I7Zvro4Jz7toiZhFRIZDkFRs87WNmB8Vu18S3BD6JXfIzcJCZlQFwzmXjZxfda2YVzaw2vk/+2dj1TwF9zewY8w6JXZNjHdAaaGFm90f0tJ4C7jazQ2MZmphZVfzsqMPMrKOZlY59HBsbFBcpEhUCSWbrgOOAT81sA74AfAn0iX3/HeAr4CczWx27ryewAVgMTMcPvo4CcM69CNwbu28d8ApQJfcDOud+B1oBZ5nZ3RE8p4H4YjUFWAuMBMo759YBZwAXA8uBn4AHgLIRZJA0YzqYRkQkvalFICKS5lQIRETSnAqBiEiaUyEQEUlzSbGgrFq1aq5OnTqhY4iIJJVZs2atds5Vz++6pCgEderUITMzM3QMEZGkYmZL879KXUMiImlPhUBEJM2pEIiIpDkVAhGRNKdCICKS5iIrBGZW08zejR3795WZ9Yrdf2fsQO85sY82UWUQEZH8RTl9dDvQxzk328wqArPM7O3Y9wY55wZE+NgiIlJAkbUInHMrYsfpEdtCdz7/OwkqPt5+G+6Patt4EZEIbdgAvXvDd7s7NK94xGWMwMzqAEcBn8bu6mFmc81sVO4Dx3f6ma5mlmlmmatWrSrcA7/9NtxxB/z8c+F+XkQklBdfhMGDYfnyyB8q8kJgZhXwh2z3jh02Pgx/FF9TYAX+8PC/cM6NcM5lOOcyqlfPd4V03q66CrZvh2eeKdzPi4iEMnIkHHYYnHhi5A8VaSEws9L4IjDOOfcygHPuZ+dctnNuB/Ak0CyyAPXr
+z/iU0+BDuARkWSxYAFMn+7fzJpF/nBRzhoy/DF7851zA3Pdf2Cuy9rhjxaMTpcusHCh/6OKiCSDkSOhVCm4/PK4PFyULYITgI7AqTtNFX3QzOaZ2VzgFOCGCDPAhRfCPvv4P6yISKLbutV3Z599NhxwQFweMrLpo8656UBebZo3onrMPO29N1xyif/DDh4MlSrF9eFFRPbI5MmwcqXvFoqT9FhZ3KULbNoEzz8fOomIyO499RTUqAFnnhm3h0yPQnDMMdCkif8Di4gkqqwseOstuPJKP0YQJ+lRCMx8q2DWLPj889BpRETyNmoU7NgBnTvH9WHToxAAXHYZlCsHTz4ZOomIyF9lZ/tei1at4OCD4/rQ6VMIKleG9u1h3Di/dFtEJJFMmQLLlsHVV8f9odOnEAB07Qpr18KECaGTiIj82YgRUL06nHtu3B86vQrBCSfA4Yf7P7iISKJYsQImTYJOnaBMmbg/fHoVAjPf7PrkE5g3L3QaERFv9Gg/RtClS5CHT69CANCxo6+4GjQWkUSwY4cfJD7lFDj00CAR0q8QVKsGF1wAY8fCxo2h04hIups2Db7/3o9hBpJ+hQD8H/z33/1+3yIiIT3xhH+D2q5dsAjpWQhatvRbVA8fHjqJiKSz5cvhlVf8IHHZssFipGchMINu3fyg8RdfhE4jIulq1Cg/SBywWwjStRCA3+e7XDnfLBMRibfsbD+VvVUrOOSQoFHStxBUqQIXXQTPPgvr14dOIyLp5s03/Uria64JnSSNCwH47qF167Q9tYjE3/Dh/uCZc84JnSTNC0Hz5n576mHDdKaxiMTPDz/AG2/4BWSlS4dOk+aFIGfQ+PPPYebM0GlEJF3kbHMTaCXxztK7EIDfnrpCBXj88dBJRCQdbN3qdzZo2xZq1w6dBlAhgIoV/QyiF16AX34JnUZEUt3LL/szia+9NnSSP6gQgP8PsmWL3/hJRCRKjz/uD54544zQSf6gQgDQqBG0aOEHjXfsCJ1GRFLVvHnw4Yf+zWeJxHn5TZwkoV13HSxe7E8JEhGJwrBhfiuJTp1CJ/kTFYIc7drB/vtr0FhEorF2rd/1+OKLoWrV0Gn+RIUgR5ky/tCayZNhyZLQaUQk1Ywd63cxSKBB4hwqBLl17er77YYNC51ERFKJczB0KGRkQLNmodP8hQpBbjVrwnnn+dOCNm0KnUZEUsU778D8+dCjh1/ImmBUCHbWowf8+qtfVyAiUhwee8wfPtOhQ+gkeVIh2FnLlnDEEfDoo9p/SESKbulSeO01PwZZrlzoNHlSIdiZmW8VfP65P7hGRKQock5C7NYtbI7dUCHIy2WXQaVKvlUgIlJYmzb5fYXOPRdq1QqdZpciKwRmVtPM3jWz+Wb2lZn1it1fxczeNrNFsc+Vo8pQaBUq+AUfL74IK1aETiMiyWr8eL+HWc+eoZPsVpQtgu1AH+fc4UBzoLuZNQRuAaY55w4FpsVuJ57u3f1RcjrgXkQKwzkYPNiPOZ58cug0uxVZIXDOrXDOzY59vQ6YD9QAzgWejl32NHBeVBmK5JBDoE0bXwi2bAmdRkSSzYcfwpw5cP31CTllNLe4jBGYWR3gKOBTYH/n3ArwxQLYbxc/09XMMs0sc9WqVfGI+Ve9evntYsePD/P4IpK8hgzxZ6NfdlnoJPmKvBCYWQXgJaC3c25tQX/OOTfCOZfhnMuoXr16dAF35/TToWFD37zTVFIRKailS2HiRD9ldK+9QqfJV6SFwMxK44vAOOfcy7G7fzazA2PfPxBYGWWGIjHzzbrZs+Gjj0KnEZFkMXSof/247rrQSQokyllDBowE5jvnBub61mvAFbGvrwBejSpDsbjsMth3X98qEBHJz4YNfspou3YJPWU0tyhbBCcAHYFTzWxO7KMNcD/QyswWAa1itxPX3nv75t3EifDDD6HTiEiiGzsWfv/djzEmCXNJ0PedkZHhMjMzwwX44Qd/tNwNN8BDD4XLISKJbccOP65Y
oQLMnBl8tpCZzXLOZeR3nVYWF0StWnDBBb65t3596DQikqjefBO++ca/aUzwKaO5qRAU1I03wpo1OuBeRHZt0CCoUQPatw+dZI+oEBTUccfB8cf7QePs7NBpRCTRzJsHU6f6TSvLlAmdZo+oEOyJG26A776DSZNCJxGRRDNokF8z0LVr6CR7TIVgT7RrB7Vr+//gIiI5fv4Zxo2DK6/0q4mTjArBnihVyi8w++ADmDUrdBoRSRTDhsHWrUk1ZTQ3FYI91aUL7LMPPPxw6CQikgg2bvQric8+Gw47LHSaQlEh2FP77OP7ACdM8PuJiEh6e+YZWL0abropdJJCUyEojJxtZbXthEh6y86GgQOhWTM48cTQaQpNhaAwataEiy/2C8x+/z10GhEJZdIkWLQI+vZNqgVkO1MhKKw+ffwq4xEjQicRkVAGDIC6df2MwiSmQlBYTZv68woGD/azBUQkvcyY4benv+EGP6MwiakQFEXfvrB8OTz3XOgkIhJvAwb4Leo7dQqdpMhUCIrijDOgSRO/I+mOHaHTiEi8fPON35q+e3e/02iSUyEoCjP45z/h66/h9ddDpxGReBkwAMqW9TMIU4AKQVFddJHfduLBB0MnEZF4WLHCrx3o1An22y90mmKhQlBUpUv7LaqnT4ePPw6dRkSiNngwbN/uZw6mCBWC4nDVVVC1qloFIqlu7Vq/r9CFF0K9eqHTFBsVguKw995+D/JXX4X580OnEZGoPPGELwb//GfoJMVKhaC49Ojh9yJ/4IHQSUQkCps3++0kTj8djjkmdJpipUJQXKpV85vRPfssLFkSOo2IFLcxY+Cnn+DWW0MnKXYqBMWpTx8oUcJPLROR1LF9ux8DbN4cTj45dJpip0JQnA46CC6/HEaO9CcWiUhqGD8evv8e+vVL6s3ldkWFoLjdfLPfe+iRR0InEZHisGMH3HcfNGoEbduGThMJFYLiduih0L69P7FIW1SLJL9Jk+Crr3xroERqvmSm5rMKrV8/WLcOHnssdBIRKQrn4N574eCD/S4CKUqFIApHHumbkIMG+YIgIslpyhSYORNuuSXpt5reHRWCqNxxB/z6KwwfHjqJiBSGc3D33f5EwiuuCJ0mUioEUWnWzG9TPWAAbNwYOo2I7Kn33/cHz9x8M5QpEzpNpFQIonT77bBypT/bWESSy913wwEHQOfOoZNELrJCYGajzGylmX2Z6747zexHM5sT+2gT1eMnhJNOgpYt/UKUzZtDpxGRgvr4Y3jnHbjpJihfPnSayEXZIhgDtM7j/kHOuaaxjzcifPzEcMcd/jjL0aNDJxGRgrr7br9tzDXXhE4SF5EVAufcB8CvUf3+pHHqqXDCCX5BypYtodOISH4+/RTefNOfSb733qHTxEWIMYIeZjY31nVUeVcXmVlXM8s0s8xVq1bFM1/xMoM774Rly9QqEEkGd93lzxfp3j10kriJdyEYBtQDmgIrgId3daFzboRzLsM5l1G9evV45YvGaaf5VsF//qNWgUgi+/RT+O9//dhAChxKX1BxLQTOuZ+dc9nOuR3Ak0CzeD5+MGoViCSHNGwNQJwLgZkdmOtmO+DLXV2bctQqEElsadoagGinjz4PzADqm1mWmV0FPGhm88xsLnAKcENUj59wcrcKRo0KnUZEdnbnnWnZGgAw51zoDPnKyMhwmZmZoWMUnXPQogUsXgzffpsW85NFksLHH/sW+wMPpNR5xGY2yzmXkd91WlkcT2Z+fvLy5f4QbBFJDHfcAfvv788eT0MqBPF28sl+bcF998GGDaHTiMi77/pVxP36wV57hU4ThApBCHff7fcg0nkFImE551sDNWqkzSrivKgQhPC3v8FZZ/k9iNauDZ1GJH299ZbfYfT226FcudBpgtmjQmBm6bHeOh7+/W9/XsGgQaGTiKSnnNZAnTppscPo7hSoEJjZ38zsa2B+7PaRZvZ4pMlSXUYGtGsHDz8Mq1eHTiOSfl5+GTIzoX//lD9v
ID8FbREMAs4EfgFwzn0BtIgqVNq45x4/YHzffaGTiKSX7dt9d9Dhh0PHjqHTBFfgriHn3LKd7sou5izpp2FD/z/h0KF+oZmIxMfYsbBggX8zVrJk6DTBFbQQLDOzvwHOzMqYWV9i3URSRHfeCTt2+JlEIhK9LVv8v7tjj/Xds1LgQtAN6A7UALLwu4em3zrsKNSpA926+W0nFi4MnUYk9Q0fDj/84Pf9MgudJiFoi4lE8PPPUK8etGkDEyaETiOSutauhUMOgcaNYdq00GkiV6xbTJjZ02a2b67blc1MO6cVl/33hz594MUX4bPPQqcRSV0PPQSrVsH994dOklAK2jXUxDn3e84N59xvwFHRREpTfftC9ep+w6skaKWJJJ0VK2DgQLjoIj8+IH8oaCEokftYSTOrApSKJlKaqlgR/vUveP99vye6iBSvu+6CrVvh3ntDJ0k4BS0EDwMfm9ndZnY38DHwYHSx0lTXrn6s4OabIVuzc0WKzTffwFNP+YkZhxwSOk3CKVAhcM49A1wA/AysBM53zo2NMlhaKlPGz2T48ks/z1lEike/fv78jzvuCJ0kIe22EJjZPrHPVYCfgOeAccBPsfukuLVv7/svb78dNm4MnUYk+X34IUyc6Mff9tsvdJqElF+L4LnY51lAZq6PnNtS3Mz8gNaPP/p9iESk8Hbs8DPyatTwnyVPux3wdc61NTMDWjrnfohTJjnxRDj/fH9s3tVXwwEHhE4kkpzGj4eZM2HMmLQ9dKYg8h0jcH7F2cQ4ZJHcHnjAL4X/179CJxFJTps3+7GBpk21sVw+Cjpr6BMz08TbeDrkEOjeHUaO9IPHIrJnBg+GpUt9F2sJncG1OwXaYiJ2FkF9YAmwATB8Y6FJpOliUn6LiV359Vc/nfTYY/1JStoXRaRgVq6EQw+FFi1g0qTQaYIp6BYTBV0UdlYR80hhVKnid0ns3Rtefx3atg2dSCQ55My6GzAgdJKkkN/00XJm1hu4CWgN/OicW5rzEZeE6e6666BBA7jxRr8qUkR2b84cv3isZ0+oXz90mqSQX8fZ00AGMA/fKtB8xngrXdpPJ120CB59NHQakcTmHPTqBVWraqLFHsivEDR0zl3mnHsCuBA4KQ6ZZGdnneU//v1v3/cpInl76SX44AN/0NO+++Z/vQD5F4JtOV8457ZHnEV2Z+BA3+d5++2hk4gkpk2b4Kab/FkDXbqETpNU8isER5rZ2tjHOqBJztdmtjYeASWmQQPf5/nUUzBrVug0IonnoYdgyRI/bbSUNkfeEzqhLJmsWeMHv+rWhY8+0txokRxLlsDhh8M55/jVxAIU8wllkiAqVfIrjj/5BJ55JnQakcTRp49/Y6TpooWiQpBsOnaE44/3ZxasWRM6jUh4b78NL78Mt90GNWuGTpOUIisEZjbKzFaa2Ze57qtiZm+b2aLY58q7+x2ShxIl4LHH/Lmr/fuHTiMS1tatcP31fksW7S5aaFG2CMbgF6HldgswzTl3KDAtdlv21NFHwzXX+HUFc+aETiMSzsCBsGCBHyAuWzZ0mqQVWSFwzn0A/LrT3efiF6kR+3xeVI+f8v7zH79o5rrr/J7rIulmyRK/tqZdO2jTJnSapBbvMYL9nXMrAGKfd3lckJl1NbNMM8tctWpV3AImjcqV/cDYjBkwalToNCLx16uX7yodPDh0kqSXsIPFzrkRzrkM51xG9erVQ8dJTB07wkkn+YHj1atDpxGJn9de8x/9+2uAuBjEuxD8bGYHAsQ+a7+EojCDxx+HtWt9MRBJBxs2+AHiI47wO/NKkcW7ELwGXBH7+grg1Tg/fupp1MjvTDpqFLz/fug0ItHr398fODNsmN+UUYosspXFZvY8cDJQDfgZ6A+8AkwAagE/AO2dczsPKP+FVhbnY+NGXxDKlPGziMqVC51IJBqzZ/uDmrp0gSeeCJ0m4RV0ZbG2mEgVU6bAmWf6rXfvuit0GpHit307NG8OWVkwf76fMCG7
pS0m0s0ZZ8Cll8J998HXX4dOI1L8Hn3Ub7g4ZIiKQDFTIUglAwdCxYrQtavWFkhqWbIE7rgD/v53aN8+dJqUo0KQSvbbzxeDjz7ys4lEUoFz/s2NGQwd6j9LsVIhSDWXXw6tW8Mtt/h3USLJbswYv7HcAw9A7dqh06QkFYJUY+ZnU5jB1Vf7d1MiyWrFCj89+qSToFu30GlSlgpBKqpVy797mjoVRo8OnUakcJzze2lt3uxP5tNBTJHRXzZVdesGLVr4d1NZWaHTiOy5F16AV17x06EPOyx0mpSmQpCqSpTwq423bfOLb9RFJMlkxQro3h2OO86/mZFIqRCksnr14MEH4a23fNNaJBnkzBLatAmefloH0ceBCkGqu/ZaOPVU/65Ks4gkGTz9NEye7M/cqF8/dJq0oEKQ6nK6iMygc2ctNJPEtmyZP2fgpJP8Z4kLFYJ0ULs2DBoE774LjzwSOo1I3nbsgCuugOxsP9tNs4TiRn/pdNG5M5x3HvTrB/PmhU4j8lc5b1YGD/bjWxI3KgTpwgxGjPCbdV16KWzZEjqRyP/MnQu33urfrHTuHDpN2lEhSCfVq8PIkb5FcPvtodOIeJs3w2WX+TcpI0ZoL6EAVAjSzd//7mcSDRjg928RCe2WW/ybk5Ej/ZsViTsVgnQ0YAA0bOg3qFupY6MloMmT/ZjA9df7NykShApBOtprL798/7ff4MorNaVUwli+HDp1gqZN/cJHCUaFIF01buzPLvjvf/2JTyLxlJ0NHTv687affx7Klg2dKK2pEKSza6/1szT++U+YOTN0Gkkn990H77zj34Q0aBA6TdpTIUhnZn6A7v/9P7joIt9VJBK1d9+F/v39NGZNFU0IKgTprkoVGD8efvzR99dql1KJ0k8/wT/+4beVHj5cU0UThAqB+K1+H3wQXn3Vr+4UiUJ2ti8Ca9bAiy9ChQqhE0mMCoF4vXpBu3Zw880wfXroNJKK/vUv3y00dCg0ahQ6jeSiQiCemd+ltE4daN/eHwwiUlxefdVvK33VVb4LUhKKCoH8z777wsSJsHatLwZbt4ZOJKlg4UK/eDEjAx57LHQayYMKgfxZo0a+ZfDRR9C3b+g0kuzWr/ddjmXKwEsvQblyoRNJHnQGnPxVhw7w6ad+4Pioo9SUl8LZscO3BBYsgClToFat0IlkF9QikLw9+CCcfjp06wYffxw6jSSjf//bdzUOGACnnRY6jeyGCoHkrVQpv76gZk04/3x/hKBIQb30Etx1lz9xrHfv0GkkH0EKgZktMbN5ZjbHzDJDZJACqFLFz/bYuNFvRbFhQ+hEkgzmzPFdQs2ba9FYkgjZIjjFOdfUOZcRMIPk54gj4Lnn4PPP/SZh2qlUdufHH6FtW3/IzMsva3A4SahrSPLXtq3fqXTiRH+IiEhe1q+Hs8/2K4dffx0OPDB0IimgULOGHDDFzBzwhHNuxM4XmFlXoCtALc02CK9XL1i0CB56CA49FK6+OnQiSSTZ2X4TuS++gEmT4MgjQyeSPRCqEJzgnFtuZvsBb5vZAufcB7kviBWHEQAZGRnaCS00M3+S1Pff++2rDzoIzjordCpJBM7BDTfAa6/Bo49CmzahE8keCtI15JxbHvu8EpgINAuRQ/ZQzkyiJk3gwgt1hoF4Dz7oC8ANN0CPHqHTSCHEvRCY2d5mVjHna+AM4Mt455BCqlgR3ngD9tvPnzG7aFHoRBLS2LF+3Ojii/16AUlKIVoE+wPTzewL4DPgdefcmwFySGEdcAC89ZafQdS6tTaoS1dvveUPljnlFBgzBkpo7kmyivsYgXNuMaCRpGR32GF+Zshpp8EZZ8D77/t1B5Iepk/3ewg1auRnk+nM4aSmEi6Fd9xx8MorfnfJNm1g3brQiSQePv/cdwvWrOlbBZUqhU4kRaRCIEVz+ul+ADkz068+3rw5dCKJ0oIFcOaZfsvyqVP9WJEkPRUCKbrzzoPRo/3pU+3awZYtoRNJFBYtglNP9WMB
U6f6FoGkBBUCKR4dO8KIEfDmm3DBBSoGqebbb/2g8PbtMG2aX1QoKUOFQIpPly5+k7HXX9cJZ6lk8WJfBLZs8UXgiCNCJ5JipkIgxeuaa/zh5JMm+e2rNWaQ3BYuhBYt/A60U6dC48ahE0kEVAik+F13nW8ZvPGG37BO21cnpy+/9EVg2zY//qP9g1KWCoFE45pr/CKjd9/1s0zWrAmdSPbErFnQsiWULOnXiDRpEjqRREiFQKJz+eXwwgv+/OOWLeGnn0InkoKYNg1OPtlvJ/LBB9CgQehEEjEVAolW+/Z+vGDRIjjhBD/7RBLXhAl+cWCdOvDRR1CvXuhEEgcqBBK91q3hnXd899AJJ/jFZ5J4Hn3Ubx537LG+JVCjRuhEEicqBBIfxx3n32GWL+8HIF95JXQiyZGd7Q+Yv/56f8LYlCn+qElJGyoEEj/16/vxgsaN/dTSgQP9oSYSzvr1fjX44MG+GLz8Muy1V+hUEmcqBBJf++/vZxKdfz706eOPvNQq5DCWLIGTTvILAIcOhUGD/CwhSTsqBBJ/e+3lByVvuw1GjvQzVJYvD50qvbzzDmRk+GLw+ut+7YekLRUCCaNECbjnHnjxRZg3z78offxx6FSpzzn/zv+MM3zr7LPP/GC+pDUVAgnrwgthxgw/iNyypT/uUOMG0fj9d78h4I03+kHhTz7R5nECqBBIImjcGGbPhnPPhZtu8p9/+SV0qtSSmQlHH+3XdDz8sB8UrlgxdCpJECoEkhgqVfLdREOG+K2smzSBt98OnSr5ZWfDf/4Dxx/vt5D+8EPfIjALnUwSiAqBJA4z6NnTd1lUquT7sXv3hk2bQidLTt9/77vbbrvNz9KaMweaNw+dShKQCoEknqOP9pue9ezp57c3beoPS5eC2bHDrxJu3NgPxI8d6/d8qlIldDJJUKVCByisbdu2kZWVxWbtd5/0ypUrx0EHHUTp0qX/d2f58r6b6OyzoWtXP9+9e3e47z71be/OggX+gKCPPvK7vo4YAbVqhU4lCS5pC0FWVhYVK1akTp06mPo7k5Zzjl9++YWsrCzq1q371wtatfLvam+7zb/LnTjRD3Z26KB+7tw2bPDTcR9+GCpUgKef9seH6m8kBZC0XUObN2+matXJmJ96AAANFElEQVSqKgJJzsyoWrXq7lt2FSr4LqIZM+CAA+CSS+C003yBSHfO+cV5DRrA/ffDP/4B8+f7LcD1b0MKKGkLAaAikCIK/N/xuOP8Aqhhw/zA55FHQufOkJUVbcBE9eGHfjZQhw5QrZofRxkzxi8UE9kDSV0IJA2VLAnduvlzDW68EcaN84uibroJVq4MnS4+Zs2Cc87xu7hmZfltOjIz/RbfIoWgQlAEJUuWpGnTpjRq1Iizzz6b33//vdC/q06dOqxevfov948aNYrGjRvTpEkTGjVqxKuvvgrAmDFjWF7I/XnGjBlDjx498r2mevXqNG3alIYNG/Lkk0/meV1mZibXX399oXIUSZUqfhXyN9/41ckDB/rDVPr0gRUr4p8nHj77zJ8BnZHhWwP33usPl+/cWZvFSZGoEBRB+fLlmTNnDl9++SVVqlRh6NChxfr7s7KyuPfee5k+fTpz587lk08+oUns7NiiFIKC6tChA3PmzOG9997j1ltv5eeff/7T97dv305GRgZDhgyJNMdu1anjp0d+/bU/DW3wYKhdG664wncfJbvsbL8K+KSTfNfYjBm+ACxdCrfeqi2jpVgk7ayhP+ndu/j/0TdtCo88UuDLjz/+eObOnfvH7YceeogJEyawZcsW2rVrx1133QXAeeedx7Jly9i8eTO9evWia9euu/ydK1eupGLFilSoUAGAChUqUKFCBf7v//6PzMxMLr30UsqXL8+MGTP4+OOP6du3L9u3b+fYY49l2LBhlC1blpkzZ9KrVy82bNhA2bJlmTZt2p8e4/XXX+eee+5h0qRJVKtWLc8c++23H/Xq1WPp0qUMGzaM5cuXs2TJEqpV
q0bXrl0ZMGAAkydPZv369fTs2ZPMzEzMjP79+3PBBRcwZcoU+vfvz5YtW6hXrx6jR4/+4zkVm/r1/UyZf/3LF4NRo+CZZ/wLaJcuvtWQTC+aWVm+v3/kSL9DaN26frO4q67S9FkpdmoRFIPs7GymTZvGOeecA8CUKVNYtGgRn332GXPmzGHWrFl88MEHgO/qmTVrFpmZmQwZMoRfdrOnzpFHHsn+++9P3bp16dSpE5MmTQLgwgsvJCMjg3HjxjFnzhzMjCuvvJLx48czb948tm/fzrBhw9i6dSsdOnRg8ODBfPHFF0ydOpXy5cv/8fsnTpzI/fffzxtvvLHLIgCwePFiFi9ezCGHHALArFmzePXVV3nuuef+dN3dd99NpUqVmDdvHnPnzuXUU09l9erV3HPPPUydOpXZs2eTkZHBwIEDC/eHLoh69fz6g6wseOgh+Okn3zo48EC/HmHaNL/VQiJas8a3btq08a2aO+6Agw+Gl17yZz737q0iIJFIjRbBHrxzL06bNm2iadOmLFmyhGOOOYZWrVoBvhBMmTKFo446CoD169ezaNEiWrRowZAhQ5g4cSIAy5YtY9GiRVStWjXP31+yZEnefPNNZs6cybRp07jhhhuYNWsWd95555+u++abb6hbty6HHXYYAFdccQVDhw7ltNNO48ADD+TYY48FYJ999vnjZ959910yMzOZMmXKn+7Pbfz48UyfPp2yZcvyxBNPUCW2MvWcc875U0HJMXXqVF544YU/bleuXJnJkyfz9ddfc0JsIHPr1q0cf/zxu//DFod994W+ff2YwYcfwlNPwXPPwZNPwn77wXnn+Rfc007z01NDycqC//4XJk+Gt97yh/TUrAn9+vm+/4MPDpdN0kaQQmBmrYHBQEngKefc/SFyFFXOGMGaNWto27YtQ4cO5frrr8c5R79+/bjmmmv+dP17773H1KlTmTFjBnvttRcnn3xyviujzYxmzZrRrFkzWrVqRadOnf5SCNwutm12zu1yaubBBx/M4sWLWbhwIRkZGXle06FDBx577LG/3L/33nsX+PGcc7Rq1Yrnn38+z5+JnJmfXdOiBTzxBLzxBowf74vCiBFQujT87W++C+nEE/1ePJUqRZPFOVi2zK/6nT4d3n8fvvrKf69mTbj2WrjoIj8WUEKNdYmfuP/fZmYlgaHAWUBD4BIzaxjvHMWpUqVKDBkyhAEDBrBt2zbOPPNMRo0axfr16wH48ccfWblyJWvWrKFy5crstddeLFiwgE8++WS3v3f58uXMnj37j9tz5syhdu3aAFSsWJF169YB0KBBA5YsWcK3334LwNixY2nZsiUNGjRg+fLlzJw5E4B169axPdYtUrt2bV5++WUuv/xyvsp5MSqiM84440+F47fffqN58+Z89NFHf2TbuHEjCxcuLJbH22Ply/v9+CdM8Ntcv/OO727ZsMFvXdG6tW9J1K3rt8LOOUHtnXfgu+9g7dr8z0rIzobVq/0L/OTJfjV0r15wyil+rn/t2n7R1zPPQI0avvvqyy/94O+gQX5dgIqAxFmIFkEz4Fvn3GIAM3sBOBf4OkCWYnPUUUdx5JFH8sILL9CxY0fmz5//RxdIhQoVePbZZ2ndujXDhw+nSZMm1K9fn+b57AS5bds2+vbty/LlyylXrhzVq1dn+PDhAFx55ZV069btj8Hi0aNH0759+z8Gi7t160aZMmUYP348PXv2ZNOmTZQvX56pU6f+8fvr16/PuHHjaN++PZMmTaJevXpF+hvcfvvtdO/enUaNGlGyZEn69+/P+eefz5gxY7jkkkvYEjub+J577vmjGyuYMmX8i/Mpp/jb69fDp5/6KZpz58IXX/gjHLOz//xzpUtD5cpQtqz/KFkStm3zXTqbNsFvv/21WOy9NzRq5AesmzTxLZDGjaFUavTMSvKzXXUrRPaAZhcCrZ1zXWK3OwLHOed67HRdV6ArQK1atY5ZunTpn37P/PnzOfzww+MTWiKXkP89
t2/3ffiLF/sundWrfUvit9/8C/+WLf6aMmV8UShXzr/rr1oVqlf3U1sPPth/rVXwEoCZzXLO5d33m0uItyR5/Yv4SzVyzo0ARgBkZGTo7EKJv1Kl/It5nTqhk4hEKkRnZBZQM9ftg4BoV0aJiMguhSgEM4FDzayumZUBLgZeK8wvine3lkRD/x1Fwop7IXDObQd6AG8B84EJzrk9nrZSrlw5fvnlF72IJLmc8wjKlSsXOopI2goybcE59wbwRlF+x0EHHURWVharVq0qplQSSs4JZSISRtLOXytdunTeJ1qJiMge0coVEZE0p0IgIpLmVAhERNJc3FcWF4aZrQKW5nth4qkG/PXYsdSVbs8X9JzTRbI+59rOuer5XZQUhSBZmVlmQZZ3p4p0e76g55wuUv05q2tIRCTNqRCIiKQ5FYJojQgdIM7S7fmCnnO6SOnnrDECEZE0pxaBiEiaUyEQEUlzKgRxYGZ9zcyZWbXQWaJmZg+Z2QIzm2tmE81s39CZomJmrc3sGzP71sxuCZ0namZW08zeNbP5ZvaVmfUKnSkezKykmX1uZpNDZ4mKCkHEzKwm0Ar4IXSWOHkbaOScawIsBPoFzhMJMysJDAXOAhoCl5hZw7CpIrcd6OOcOxxoDnRPg+cM0Au/ZX7KUiGI3iDgn+RxHGcqcs5NiZ05AfAJ/gS6VNQM+NY5t9g5txV4ATg3cKZIOedWOOdmx75eh39xrBE2VbTM7CDg78BTobNESYUgQmZ2DvCjc+6L0FkC6Qz8N3SIiNQAluW6nUWKvyjmZmZ1gKOAT8Mmidwj+DdyO0IHiVLSnkeQKMxsKnBAHt+6DbgVOCO+iaK3u+fsnHs1ds1t+K6EcfHMFkeWx31p0eozswrAS0Bv59za0HmiYmZtgZXOuVlmdnLoPFFSISgi59zped1vZo2BusAXZga+i2S2mTVzzv0Ux4jFblfPOYeZXQG0BU5zqbtQJQuomev2QcDyQFnixsxK44vAOOfcy6HzROwE4BwzawOUA/Yxs2edc5cFzlXstKAsTsxsCZDhnEvGHQwLzMxaAwOBls65lD1H1MxK4QfDTwN+BGYC/yjM+dvJwvw7mqeBX51zvUPniadYi6Cvc65t6CxR0BiBFLfHgIrA22Y2x8yGhw4UhdiAeA/gLfyg6YRULgIxJwAdgVNj/23nxN4tS5JTi0BEJM2pRSAikuZUCERE0pwKgYhImlMhEBFJcyoEIiJpTgvKRHIxs6rAtNjNA4BsIGc9xEbn3N+CBBOJkKaPiuyCmd0JrHfODQidRSRK6hoSKSAzWx/7fLKZvW9mE8xsoZndb2aXmtlnZjbPzOrFrqtuZi+Z2czYxwlhn4FI3lQIRArnSPw+9Y3xq20Pc841w29X3DN2zWBgkHPuWOACUnwrY0leGiMQKZyZzrkVAGb2HTAldv884JTY16cDDWObDoLftKxibC9/kYShQiBSOFtyfb0j1+0d/O/fVQngeOfcpngGE9lT6hoSic4U/MZ0AJhZ04BZRHZJhUAkOtcDGWY218y+BrqFDiSSF00fFRFJc2oRiIikORUCEZE0p0IgIpLmVAhERNKcCoGISJpTIRARSXMqBCIiae7/AzTZwvREOkEVAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# visualising the results\n",
    "%matplotlib inline\n",
    "plt.plot(training_set_ori.loc[:,['x']], training_set_ori.loc[:,['y']],color = 'red', label = 'Real Stock Price')\n",
    "plt.title('Stock Price')\n",
    "plt.xlabel('Time')\n",
    "plt.ylabel('Price')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inputs and outputs of the recurrent network: not (date, stock price) pairs,\n",
    "# but the stock price at time t as input and the stock price at time t+1 as output.\n",
    "# Keep only the \"open\" price column, as a two-dimensional NumPy array of shape (m, 1).\n",
    "# This assumes the loaded CSV contains an 'open' column.\n",
    "training_set_ori = training_set_ori.loc[:,['open']].values\n",
    "print(training_set_ori.shape,type(training_set_ori))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feature scaling (min-max normalization): the LSTM uses several sigmoid activations,\n",
    "# whose outputs lie between 0 and 1, the same range that normalization produces.\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "sc = MinMaxScaler() # default feature_range is (0, 1)\n",
    "# fit_transform learns the min and max of the training set, then scales it with them\n",
    "training_set = sc.fit_transform(training_set_ori)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(training_set.shape,type(training_set))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(training_set, color = 'red', label = 'Normalized Stock Price')\n",
    "plt.title('Normalized Stock Price')\n",
    "plt.xlabel('Time')\n",
    "plt.ylabel('Normalized Price')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Getting the inputs and the outputs, y_train is output, x_train is the input\n",
    "m = training_set.shape[0]\n",
    "X_train = training_set[0:m-1]\n",
    "y_train = training_set[1:m]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(X_train.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reshape the inputs: X_train is currently a 2D array of shape (samples, features).\n",
    "# Keras recurrent layers expect a 3D tensor of shape (batch_size, timesteps, input_dim),\n",
    "# so add a timestep axis. Here timesteps = 1 (each sample is the single previous price)\n",
    "# and input_dim = 1 (one feature per timestep).\n",
    "X_train = np.reshape(X_train, (m-1, 1, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Part 2 - Building the RNN\n",
    "\n",
    "# Importing the keras libs and packages\n",
    "from keras.models import Sequential\n",
    "from keras.layers import Dense\n",
    "from keras.layers import LSTM\n",
    "\n",
    "# Initialising the RNN\n",
    "# predicting a continuous outcome, regression model\n",
    "regressor = Sequential()\n",
    "\n",
    "# Adding the input layer and the LSTM layer\n",
    "regressor.add(LSTM(units = 4, activation = 'sigmoid', input_shape = (None, 1)))\n",
    "\n",
    "# Adding the output layer\n",
    "regressor.add(Dense(units = 1, activation = 'linear'))\n",
    "\n",
    "# Compiling the RNN\n",
    "# use the mean square error\n",
    "# regression won't be binary cross entropy, MSE for regression\n",
    "regressor.compile(optimizer = 'adam', loss = 'mean_squared_error')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fitting the RNN to the training set\n",
    "regressor.fit(X_train, y_train, batch_size = 32, epochs = 200)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Part 3 - Making the predictions and visualising the results based on the training data\n",
    "\n",
    "# Get the real stock price 2016 - 2018\n",
    "# Importing the training set\n",
    "real_stock_price = pd.read_csv(\"zgpa_2016-18.csv\")\n",
    "real_stock_price_train = real_stock_price.iloc[:,1:2].values\n",
    "# Getting the predicted stock price of 2016 - 2018\n",
    "predicted_stock_price_train = regressor.predict(X_train)\n",
    "print(predicted_stock_price_train.shape)\n",
    "predicted_stock_price_train = sc.inverse_transform(predicted_stock_price_train)\n",
    "# visualising the results\n",
    "plt.plot(real_stock_price_train[1:m], color = 'red', label = 'Real Stock Price')\n",
    "plt.plot(predicted_stock_price_train, color = 'blue', label = 'Predicted Stock Price')\n",
    "plt.title('Stock Price Prediction')\n",
    "plt.xlabel('Time')\n",
    "plt.ylabel('Price')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualising the results; the prediction is offset by +10 to separate the two curves\n",
    "plt.plot(real_stock_price_train[1:m], color = 'red', label = 'Real Stock Price')\n",
    "plt.plot(predicted_stock_price_train+10, color = 'blue', label = 'Predicted Stock Price+10')\n",
    "plt.title('Stock Price Prediction')\n",
    "plt.xlabel('Time')\n",
    "plt.ylabel('Price')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "result = np.concatenate((real_stock_price_train[1:m],predicted_stock_price_train),axis = 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "result = pd.DataFrame(result,columns=['real_price','predicted_price'])\n",
    "result.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "result.to_csv('result.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Part 4 - Making the predictions and visualising the results based on the test data\n",
    "%matplotlib inline\n",
    "# Getting the real stock price of 2019\n",
    "test_set = pd.read_csv(\"zgpa_2019.csv\")\n",
    "test_set.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "real_stock_price = test_set.loc[:,['open']].values\n",
    "print(real_stock_price.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the 2019 test inputs (price at time t) and targets (price at time t+1)\n",
    "m_test = real_stock_price.shape[0]\n",
    "X_test_ori = real_stock_price[0:m_test-1]\n",
    "y_test_ori = real_stock_price[1:m_test]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(X_test_ori.shape,y_test_ori.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_test = sc.transform(X_test_ori)\n",
    "X_test = np.reshape(X_test, (m_test-1, 1, 1))\n",
    "predicted_stock_price = regressor.predict(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "predicted_stock_price = sc.inverse_transform(predicted_stock_price)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualising the results\n",
    "plt.plot(y_test_ori, color = 'red', label = 'Real Stock Price')\n",
    "plt.plot(predicted_stock_price, color = 'blue', label = 'Predicted Stock Price')\n",
    "plt.title('Stock Price Prediction')\n",
    "plt.xlabel('Time')\n",
    "plt.ylabel('Price')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Part 5 - Evaluating the RNN\n",
    "# Regression models are evaluated here with the\n",
    "# Root Mean Square Error (RMSE) between the real and predicted prices\n",
    "\n",
    "import math\n",
    "from sklearn.metrics import mean_squared_error\n",
    "rmse = math.sqrt(mean_squared_error(y_test_ori, predicted_stock_price))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rmse"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the trained model; load_model('rnn_zgpa_stock_predict.h5') can restore it later\n",
    "from keras.models import load_model\n",
    "regressor.save('rnn_zgpa_stock_predict.h5')  # creates the HDF5 file 'rnn_zgpa_stock_predict.h5'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "result = np.concatenate((y_test_ori,predicted_stock_price),axis = 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "result = pd.DataFrame(result,columns=['real_price','predicted_price'])\n",
    "result.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "result.to_csv('zgpa_predicted.csv')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Stock price prediction comparison\n",
    "\n",
    "![result_comp](images/result_comp.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prediction results with a sigmoid/tanh output-layer activation\n",
    "\n",
    "![code_update](images/code_update.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prediction results with a sigmoid/tanh output-layer activation\n",
    "\n",
    "![sigmoid_output](images/sigmoid_output.png)"
   ]
  },
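  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Limitation of prediction accuracy:** it does not show how accurate the model is for each individual class.\n",
    "\n",
    "Task: compute the model's prediction accuracy and compare it with the null accuracy.\n",
    "\n",
    "Null accuracy: the accuracy obtained when the model always predicts the most frequent class."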
   ]
  },
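  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The comparison described above can be sketched in a few lines. The labels and predictions below are made up for illustration, and the function names `prediction_accuracy` and `null_accuracy` are hypothetical helpers rather than part of the original material."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "\n",
    "def prediction_accuracy(y_true, y_pred):\n",
    "    # Fraction of samples whose predicted label equals the actual label\n",
    "    correct = sum(1 for actual, predicted in zip(y_true, y_pred) if actual == predicted)\n",
    "    return correct / len(y_true)\n",
    "\n",
    "def null_accuracy(y_true):\n",
    "    # Accuracy of a model that always predicts the most frequent class\n",
    "    most_common_count = Counter(y_true).most_common(1)[0][1]\n",
    "    return most_common_count / len(y_true)\n",
    "\n",
    "demo_true = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1]  # made-up actual labels\n",
    "demo_pred = [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]  # made-up predictions\n",
    "print('accuracy:', prediction_accuracy(demo_true, demo_pred))\n",
    "print('null accuracy:', null_accuracy(demo_true))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this made-up example the model is only slightly better than always predicting class 0, even though it finds just one of the three positive samples, which is the kind of detail that accuracy alone hides."
   ]
  },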
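  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Conclusion:**\n",
    "\n",
    "Classification accuracy is a convenient measure of a model's overall performance, but it hides the details. Specifically:\n",
    "- It does not reflect the **actual distribution** of the data\n",
    "- It does not show the **types of errors** the model makes"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Confusion Matrix\n",
    "\n",
    "Also called an error matrix; it is used to measure how accurate a classification algorithm is."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![confusion_matrix](images/06_confusion_matrix.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Terminology**\n",
    "\n",
    "- **True Positives (TP):** correctly predicted samples that are actually positive (actual 1, predicted 1)\n",
    "- **True Negatives (TN):** correctly predicted samples that are actually negative (actual 0, predicted 0)\n",
    "- **False Positives (FP):** incorrectly predicted samples that are actually negative (actual 0, predicted 1)\n",
    "- **False Negatives (FN):** incorrectly predicted samples that are actually positive (actual 1, predicted 0)"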
   ]
  },
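  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of the four terms above, the cell below counts them for binary labels encoded as 0 and 1. The helper name `confusion_counts` and the sample labels are assumptions made for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def confusion_counts(y_true, y_pred):\n",
    "    # Count TP, TN, FP and FN for binary labels encoded as 0 (negative) and 1 (positive)\n",
    "    tp = tn = fp = fn = 0\n",
    "    for actual, predicted in zip(y_true, y_pred):\n",
    "        if actual == 1 and predicted == 1:\n",
    "            tp += 1  # actual 1, predicted 1\n",
    "        elif actual == 0 and predicted == 0:\n",
    "            tn += 1  # actual 0, predicted 0\n",
    "        elif actual == 0 and predicted == 1:\n",
    "            fp += 1  # actual 0, predicted 1\n",
    "        else:\n",
    "            fn += 1  # actual 1, predicted 0\n",
    "    return {'TP': tp, 'TN': tn, 'FP': fp, 'FN': fn}\n",
    "\n",
    "# Made-up labels and predictions, for illustration only\n",
    "demo_counts = confusion_counts([1, 0, 1, 1, 0, 0], [1, 0, 0, 1, 1, 0])\n",
    "print(demo_counts)"
   ]
  },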
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![Small confusion matrix](images/09_confusion_matrix_1.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![Large confusion matrix](images/09_confusion_matrix_2.png)"
   ]
  },
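  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Confusion Matrix Metrics"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Accuracy:** the proportion of all samples that are predicted correctly\n",
    "- Accuracy = (TP + TN)/(TP + TN + FP + FN)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Precision:** the proportion of samples predicted as positive that are actually positive\n",
    "- Precision = TP/(TP + FP)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**F1 score:** a single metric that combines Precision and Recall (their harmonic mean), where Recall = TP/(TP + FN) is the proportion of actual positives that are predicted as positive\n",
    "- F1 Score = 2 × Precision × Recall/(Precision + Recall)"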
   ]
  },
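  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The metric formulas in this section can be checked with a short sketch. Recall is taken here as TP/(TP + FN), its standard definition. The function name `classification_metrics`, the choice to return 0.0 when a denominator is zero, and the counts passed in are assumptions for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def classification_metrics(tp, tn, fp, fn):\n",
    "    # Compute accuracy, precision, recall and F1 from confusion-matrix counts.\n",
    "    # A ratio with a zero denominator is reported as 0.0 (a convention assumed here).\n",
    "    total = tp + tn + fp + fn\n",
    "    accuracy = (tp + tn) / total if total else 0.0\n",
    "    precision = tp / (tp + fp) if (tp + fp) else 0.0\n",
    "    recall = tp / (tp + fn) if (tp + fn) else 0.0\n",
    "    denominator = precision + recall\n",
    "    f1 = 2 * precision * recall / denominator if denominator else 0.0\n",
    "    return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1}\n",
    "\n",
    "# Made-up counts, for illustration only\n",
    "demo_metrics = classification_metrics(tp=2, tn=2, fp=1, fn=1)\n",
    "print(demo_metrics)"
   ]
  },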
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![confusion matrix_metrics](images/09_confusion_matrix_3.png)"
   ]
  },
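  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Conclusion:**\n",
    "\n",
    "- For classification tasks, the confusion matrix provides **more complete evaluation information** than prediction accuracy alone\n",
    "- From the confusion matrix we can compute **a variety of performance metrics**, which helps in choosing a model\n",
    "\n",
    "**Which metric matters most?**\n",
    "\n",
    "- The choice of metric depends on the **application**\n",
    "- **Spam detection** (positive class: \"spam\"): ordinary mail (negative) should not be flagged as spam (positive), so watch **precision**, which checks that messages flagged as spam really are spam; also watch **recall**, so that as much spam as possible is caught\n",
    "- **Anomalous transaction detection** (positive class: \"anomalous transaction\"): every anomalous transaction should be detected, meaning that as few anomalies as possible remain among the transactions judged normal, so watch **recall** (sensitivity)"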
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
